[BugFix][Mamba] Fix da_cumsum kernel to support dt_bias, softplus, and clamp #1118
Ibuki-wind merged 5 commits into tile-ai:main
Conversation
…lamp; update chunk_state, tests, and benchmarks
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…-flag test cases
- Remove is_cuda/dtype/has_dt_bias validation from DaCumsumFwdKernel.forward; per project pattern, validation belongs only in Op.forward (system boundary)
- Add bias-only (True, False) and softplus-only (False, True) smoke cases to DaCumsumFwdFixture; each guards an independent compile-time branch in the kernel
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Code Review
This pull request updates the da_cumsum operator and kernel to support optional per-head bias, softplus activation, and clamping, aligning the implementation with the Mamba-2 Triton reference. Additionally, the ssd_chunk_state kernel is optimized by reordering the fused axis decoding to improve L2 cache reuse and by eliminating intermediate register fragments for scaled inputs, which allows for larger tile configurations. Feedback was provided regarding the da_cumsum kernel's autotune configuration, noting that since the scan is implemented serially within blocks, using multiple threads leads to redundant work and hardware contention.
The inner scan is T.serial(Q): every thread in the block executes the
same loop and writes to the same output locations. Configs with threads
> 1 produce redundant work and write contention with no throughput
benefit. Remove {32, 64, 128} from the autotune search space.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
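For illustration, a minimal sketch of the pruned search space; the `block_h`/`threads` config keys are assumptions, not the project's actual autotune API, and only the removed thread counts come from the comment above:

```python
# Hypothetical sketch: keep only single-thread configs for da_cumsum.
# "block_h" and "threads" are assumed config keys, not the project's API.
autotune_configs = [
    {"block_h": block_h, "threads": 1}  # threads in {32, 64, 128} removed
    for block_h in (1, 2, 4, 8)
]
```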
Ibuki-wind left a comment
Overall
One correctness blocker remains; address the inline comment and rerun the affected Mamba tests.
…lru_cache, chunk_scan load order
…e configs
- Swap state_tile load loop from (nn, pp) to (pp, nn) so consecutive threads iterate over the contiguous N dimension, giving coalesced 128-byte global loads instead of strided-by-N accesses.
- Expand autotune_configs: block_n [16, 32] -> [32, 64, 128], block_s [32, 64] -> [64, 128] to cover the larger tile sizes used by H200.
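A runnable back-of-envelope illustration of why the swap coalesces, assuming the `[B, C, H, P, N]` `prev_states` layout described in the summary below (all sizes are placeholders):

```python
# Placeholder sizes; prev_states layout [B, C, H, P, N] with N innermost.
P, N = 64, 64
block_p, block_n = 4, 8

def offset(pp, nn):
    return pp * N + nn  # linear element offset: N is the contiguous axis

# (nn, pp) order: inner index pp varies fastest across adjacent threads,
# so consecutive threads land N elements apart (strided-by-N).
strided = [offset(pp, 0) for pp in range(block_p)]    # [0, 64, 128, 192]
# (pp, nn) order: inner index nn varies fastest, so consecutive threads
# read adjacent elements (coalesced global loads).
coalesced = [offset(0, nn) for nn in range(block_n)]  # [0, 1, ..., 7]
print(strided, coalesced)
```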
Summary
da_cumsum
Accepts raw `dt` and applies the full pipeline as compile-time-conditional steps:
- add `dt_bias` (`has_dt_bias=True`)
- apply softplus, passing `dt` through unchanged when `dt > 20` (`dt_softplus=True`)
- clamp to `[dt_min, dt_max]`
- compute `dA = dt_out * A`
Returns two outputs: `dt_out` (the processed `dt`) and `dA_cumsum`. The kernel signature is fixed regardless of flags; unused inputs are dummy-zeroed at the op boundary only when `has_dt_bias=False`. Calling `forward` with `dt_bias=None` when `has_dt_bias=True` now raises `ValueError` immediately instead of silently computing un-biased results.
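A hedged PyTorch reference of this pipeline; the `[B, H, C, L]` `dt` shape and per-head `A`/`dt_bias` vectors are assumptions, and the 20.0 softplus threshold follows the Mamba-2 Triton reference:

```python
import torch
import torch.nn.functional as F

def da_cumsum_ref(dt, A, dt_bias=None, dt_softplus=False,
                  dt_min=0.0, dt_max=float("inf")):
    """Reference semantics only; shapes assumed as dt [B, H, C, L], A [H]."""
    dt_out = dt.float()
    if dt_bias is not None:                       # has_dt_bias branch
        dt_out = dt_out + dt_bias.view(1, -1, 1, 1)
    if dt_softplus:                               # linear passthrough above 20
        dt_out = torch.where(dt_out <= 20.0, F.softplus(dt_out), dt_out)
    dt_out = dt_out.clamp(dt_min, dt_max)
    dA = dt_out * A.view(1, -1, 1, 1)             # dA = dt_out * A
    return dt_out, dA.cumsum(dim=-1)              # cumsum within each chunk
```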
ssd_chunk_scan
Tensor layouts updated to match the official `_chunk_scan_fwd`:
- `x`, `C`, `out` changed from chunk-fused `[B, C, L, H, P]` to seqlen-fused `[B, S, H, P]` (see the conversion sketch after this list)
- `cb` changed from head-owned `[B, C, H, L, L]` to group-owned `[B, C, G, L, L]`
- `prev_states` axis order changed from `[B, C, H, N, P]` to `[B, C, H, P, N]` (P before N, the official convention)
- `dt` layout changed from `[B, C, L, H]` to `[B, H, C, L]`
- `n_groups` added as a constructor parameter
- `dA_l` shared-memory load moved to just before it is consumed (after the history path), eliminating a redundant `sync_threads` stall
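For orientation, a minimal sketch of converting old-layout tensors to the new layouts (all sizes are placeholders):

```python
import torch

B, C, L, H, P, N = 2, 3, 16, 8, 64, 32      # placeholder sizes; S = C * L
S = C * L

x_old = torch.randn(B, C, L, H, P)          # chunk-fused [B, C, L, H, P]
x_new = x_old.reshape(B, S, H, P)           # seqlen-fused [B, S, H, P]

prev_old = torch.randn(B, C, H, N, P)       # old [B, C, H, N, P]
prev_new = prev_old.transpose(-1, -2).contiguous()  # new [B, C, H, P, N]

dt_old = torch.randn(B, C, L, H)            # old [B, C, L, H]
dt_new = dt_old.permute(0, 3, 1, 2).contiguous()    # new [B, H, C, L]
```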
ssd_state_passing
Fixed output convention to match the Mamba-2 spec: `out[:, c]` now holds the state *before* chunk `c`, so `out[:, 0] = initial_states` and `out[:, c+1] = s_c` for `c` in `[0, C-2]`. Reference implementations in tests and benchmarks were updated to match.
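A hedged reference for the shifted output convention, assuming `states[:, c]` already holds the accumulated state `s_c` after chunk `c`:

```python
import torch

def shifted_states_ref(states, initial_states):
    # states: [B, C, ...] where states[:, c] = s_c (accumulated state after
    # chunk c; an assumption of this sketch). out[:, c] = state BEFORE chunk c.
    out = torch.empty_like(states)
    out[:, 0] = initial_states        # out[:, 0] = initial_states
    out[:, 1:] = states[:, :-1]       # out[:, c+1] = s_c for c in [0, C-2]
    return out
```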
All Mamba kernels
Added `@functools.lru_cache` to all five kernel factory functions (`da_cumsum`, `ssd_chunk_scan`, `ssd_chunk_state`, `ssd_decode`, `ssd_state_passing`) to prevent redundant TileLang recompilation on repeated calls with identical static parameters.
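A minimal sketch of the caching pattern; the factory name and its parameter list here are assumptions, not the repo's signatures:

```python
import functools

@functools.lru_cache  # one compile per unique static-parameter tuple
def da_cumsum_kernel(chunk_size: int, has_dt_bias: bool, dt_softplus: bool):
    # Hypothetical factory body: the real one builds and JIT-compiles a
    # TileLang kernel. Hashable static args are what make lru_cache apply.
    print(f"compiling da_cumsum(chunk={chunk_size}, bias={has_dt_bias})")
    return object()  # stand-in for the compiled kernel

k1 = da_cumsum_kernel(256, True, True)
k2 = da_cumsum_kernel(256, True, True)   # cache hit: no recompilation
assert k1 is k2
```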
Test plan
- `tests/ops/test_mamba.py`: all existing and new test cases pass, including the new smoke test `test_da_cumsum_fwd_missing_bias_raises` (sketched below)
- `benchmarks/ops/bench_mamba.py`: benchmark executes cleanly with no regressions
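A hedged sketch of what the new smoke test checks, using a stand-in op that carries the documented boundary validation (the real Op class lives in the repo):

```python
import pytest
import torch

class _OpStub:
    """Stand-in with the documented boundary check; not the repo's Op class."""
    def __init__(self, has_dt_bias: bool):
        self.has_dt_bias = has_dt_bias

    def forward(self, dt, A, dt_bias=None):
        if self.has_dt_bias and dt_bias is None:
            # fail fast instead of silently computing un-biased results
            raise ValueError("has_dt_bias=True but dt_bias is None")
        ...  # kernel launch elided

def test_da_cumsum_fwd_missing_bias_raises():
    op = _OpStub(has_dt_bias=True)
    dt, A = torch.rand(2, 8, 4, 64), torch.randn(8)  # assumed [B, H, C, L]
    with pytest.raises(ValueError):
        op.forward(dt, A, dt_bias=None)
```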